import os
import glob
import numpy as np
import pandas as pd
import tkinter as tk
from tkinter import filedialog, messagebox

import mne


# =========================
# USER SETTINGS
# =========================
# Filename patterns (glob syntax) used to find epoch files directly inside
# each selected folder.
FILE_GLOB_PATTERNS = ["*_manualOutliersRemoved-epo.fif", "*_cleaned-epo.fif"]

# Event labels looked up in each file's event_id (matching ignores case and
# surrounding whitespace). TR and TL are averaged into the "TST" condition,
# GR and GL into the "GT" condition.
TR_LABEL = "TR"
TL_LABEL = "TL"
GR_LABEL = "GR"
GL_LABEL = "GL"

# Analysis window in milliseconds; both ends are inclusive.
TIME_WINDOW_MS = (300.0, 500.0)

# Baseline interval handed to mne.baseline.rescale together with the epoch
# time axis (presumably seconds, as the _S suffix suggests -- confirm).
BASELINE_S = (-0.2, 0.0)
BASELINE_MODE = "ratio"  # "ratio", "db", or "none"

# Passed as `decim` to the time-frequency computation.
DECIMATE = 1
# Density of the frequency grid inside each band, in points per Hz.
FREQS_PER_HZ = 2.0  # ~0.5 Hz step

# IMPORTANT: We start with a target cycles rule, but we CAP it to your epoch length
N_CYCLES_TARGET_MODE = "fixed"  # "fixed" or "proportional"
FIXED_N_CYCLES_TARGET = 5.0     # target; will be capped down if too long
PROPORTIONAL_DIVISOR = 2.0      # if proportional, n_cycles = freqs / PROPORTIONAL_DIVISOR

# Bands (Hz): name -> (low edge, high edge); both edges are included in the
# frequency grid.
BANDS = {
    "theta": (4.0, 7.0),
    "alpha": (8.0, 12.0),
    "beta":  (13.0, 30.0),
}

# Channel names making up each region of interest. Channels that are absent
# from a recording are skipped when the ROI mean is taken.
ROI_POSTERIOR = ["POz", "P3", "P4"]
ROI_FRONTOCENTRAL = ["Fz", "Cz"]
# NOTE(review): lists the same channels as ROI_POSTERIOR -- confirm intended.
ROI_PARIETAL = ["POz", "P3", "P4"]


# =========================
# GUI AND FILE HELPERS
# =========================
def pick_folder(title: str) -> str:
    """Show a folder-chooser dialog and return the selected directory.

    Args:
        title: Caption shown in the dialog's title bar.

    Returns:
        The chosen directory path, or an empty string when the dialog is
        cancelled.
    """
    # The dialog needs a Tk root; it is withdrawn so that no empty window
    # appears next to the chooser.
    root = tk.Tk()
    try:
        root.withdraw()
        folder = filedialog.askdirectory(title=title)
    finally:
        # Tear the hidden root down even if the dialog raises, so a stray
        # Tk instance never outlives this call.
        root.destroy()
    # A cancelled dialog yields a falsy value (reportedly an empty tuple on
    # some Tk builds rather than ""); normalise so callers always get a str.
    return folder if folder else ""

def ensure_dir(path: str) -> None:
    """Create ``path`` including missing parents; do nothing if it exists."""
    os.makedirs(name=path, exist_ok=True)

def gather_files(folder: str, patterns=None):
    """Return the sorted, de-duplicated files in ``folder`` matching the patterns.

    Args:
        folder: Directory to search (its direct children only).
        patterns: Optional iterable of glob patterns for the file names.
            Defaults to the module-level ``FILE_GLOB_PATTERNS``.

    Returns:
        Sorted list of matching paths; a file matched by several patterns
        is listed once.
    """
    if patterns is None:
        patterns = FILE_GLOB_PATTERNS
    # Escape the folder part so glob metacharacters in a directory name
    # ("[", "]", "*", "?") are taken literally; without this a folder such
    # as "data [v2]" silently yields no matches.
    literal_folder = glob.escape(folder)
    matches = set()
    for pat in patterns:
        matches.update(glob.glob(os.path.join(literal_folder, pat)))
    return sorted(matches)

def subj_id_from_path(p: str) -> str:
    """Derive a subject ID from an epoch file path.

    The directory part is dropped, then the first matching known epoch-file
    suffix is removed, and finally any ``.fif`` extension still present.

    Args:
        p: Path (or bare file name) of an epoch file.

    Returns:
        The file name with the recognised suffix/extension removed.
    """
    known_suffixes = (
        "_manualOutliersRemoved-epo.fif",
        "_cleaned-epo.fif",
        "-manualOutliersRemoved-epo.fif",
        "-cleaned-epo.fif",
    )
    name = os.path.basename(p)
    # Only the first suffix that matches is stripped.
    hit = next((suffix for suffix in known_suffixes if name.endswith(suffix)), None)
    if hit is not None:
        name = name[:len(name) - len(hit)]
    # Catches plain "*.fif" names, and a ".fif" left over after the strip.
    if name.endswith(".fif"):
        name = name[:len(name) - len(".fif")]
    return name


# =========================
# EVENT HELPERS
# =========================
def normalize_event_key(s: str) -> str:
    """Return ``s`` as text, stripped of surrounding whitespace and lower-cased."""
    text = str(s)
    trimmed = text.strip()
    return trimmed.lower()

def find_event_key(event_id: dict, desired_label: str):
    """Find the key of ``event_id`` that matches ``desired_label``.

    Both sides are compared after ``normalize_event_key``.

    Args:
        event_id: Mapping whose keys are event names.
        desired_label: Label to look for.

    Returns:
        The first matching key exactly as stored in ``event_id``, or
        ``None`` when nothing matches.
    """
    wanted = normalize_event_key(desired_label)
    return next(
        (key for key in event_id if normalize_event_key(key) == wanted),
        None,
    )

def select_epochs(epochs: mne.Epochs, label: str) -> mne.Epochs:
    """Return the subset of ``epochs`` belonging to the event named ``label``.

    The label is resolved against ``epochs.event_id`` with ``find_event_key``.

    Raises:
        ValueError: If no event key matches ``label``.
    """
    matched = find_event_key(epochs.event_id, label)
    if matched is not None:
        return epochs[matched]
    raise ValueError(f"Event '{label}' not found. Available: {list(epochs.event_id.keys())}")


# =========================
# MORLET POWER (WITH SAFE CYCLE CAP)
# =========================
def _make_freqs(f_lo: float, f_hi: float) -> np.ndarray:
    """Build an evenly spaced frequency grid from ``f_lo`` to ``f_hi`` inclusive.

    The grid has about ``FREQS_PER_HZ`` points per Hz and never fewer than
    two points.
    """
    span_hz = f_hi - f_lo
    n_points = max(int(round(span_hz * FREQS_PER_HZ)) + 1, 2)
    return np.linspace(f_lo, f_hi, n_points)

def _target_n_cycles(freqs: np.ndarray) -> np.ndarray:
    """Return the target (uncapped) number of wavelet cycles per frequency.

    Behaviour is selected by ``N_CYCLES_TARGET_MODE``:

    * ``"fixed"``: ``FIXED_N_CYCLES_TARGET`` for every frequency.
    * ``"proportional"``: ``freqs / PROPORTIONAL_DIVISOR``.

    Args:
        freqs: Frequencies in Hz.

    Returns:
        Float array with the same shape as ``freqs``.

    Raises:
        ValueError: If ``N_CYCLES_TARGET_MODE`` is not one of the two
            supported values.
    """
    if N_CYCLES_TARGET_MODE == "fixed":
        return np.full_like(freqs, float(FIXED_N_CYCLES_TARGET), dtype=float)
    if N_CYCLES_TARGET_MODE == "proportional":
        return freqs / float(PROPORTIONAL_DIVISOR)
    # Fail loudly: a misspelled mode (e.g. "Fixed") used to fall through to
    # the proportional rule and silently change the analysis.
    raise ValueError(
        f"Unknown N_CYCLES_TARGET_MODE {N_CYCLES_TARGET_MODE!r}; "
        "expected 'fixed' or 'proportional'."
    )

def _cap_n_cycles_to_epoch(freqs: np.ndarray, sfreq: float, n_times: int, n_cycles: np.ndarray) -> np.ndarray:
    """
    MNE raises if any wavelet is longer than the signal.
    A conservative approximation for Morlet support length in samples is ~ 2 * sfreq * (n_cycles / freq).
    Enforce: 2 * sfreq * (n_cycles / freq) <= n_times - 1  -> n_cycles <= (n_times - 1) * freq / (2*sfreq)
    """
    max_cycles = (max(n_times - 1, 1) * freqs) / (2.0 * sfreq)
    # Keep cycles >= 1.0 (below that, wavelets get weird/noisy); cap to max_cycles
    capped = np.minimum(n_cycles, max_cycles)
    capped = np.clip(capped, 1.0, None)

    # If still impossible (very short epochs), bump low freqs slightly by letting cycles go to 1
    return capped

def compute_band_power_morlet(ep: mne.Epochs, f_lo: float, f_hi: float) -> np.ndarray:
    """
    Compute single-trial Morlet power averaged over the band [f_lo, f_hi].

    Power is computed on the grid from ``_make_freqs`` with cycle counts from
    ``_target_n_cycles`` capped by ``_cap_n_cycles_to_epoch``, averaged across
    frequencies, and then baseline-corrected over ``BASELINE_S`` according to
    ``BASELINE_MODE``:

    * ``"none"``: no baseline correction.
    * ``"db"``: 10 * log10(power / baseline mean).
    * anything else: passed to ``mne.baseline.rescale`` as its ``mode``.

    Returns band-averaged Morlet power: (n_epochs, n_ch, n_times), where the
    time axis is the one of the computed TFR (decimated by ``DECIMATE``).
    """
    freqs = _make_freqs(f_lo, f_hi)
    sfreq = float(ep.info["sfreq"])
    n_times = len(ep.times)

    n_cycles = _target_n_cycles(freqs)
    n_cycles = _cap_n_cycles_to_epoch(freqs, sfreq, n_times, n_cycles)

    tfr = ep.compute_tfr(
        method="morlet",
        freqs=freqs,
        n_cycles=n_cycles,
        return_itc=False,
        average=False,
        decim=DECIMATE,
        verbose=False,
    )
    power = tfr.data.mean(axis=2)  # avg across freq

    if BASELINE_MODE == "none":
        return power

    # Use the TFR's own time axis: with DECIMATE > 1 it is shorter than
    # ep.times, and indexing the baseline window with the undecimated axis
    # would select the wrong samples.
    times = tfr.times

    if BASELINE_MODE == "db":
        # rescale() has no "db" mode; "logratio" is log10(power / baseline),
        # so decibels are ten times that.
        power = mne.baseline.rescale(
            power,
            times,
            BASELINE_S,
            mode="logratio",
            copy=False,
        )
        power *= 10.0
        return power

    return mne.baseline.rescale(
        power,
        times,
        BASELINE_S,
        mode=BASELINE_MODE,
        copy=False,
    )


# =========================
# UTILITIES
# =========================
def safe_divide(a: np.ndarray, b: np.ndarray, eps: float = 1e-12) -> np.ndarray:
    """Element-wise ``a / b`` with near-zero denominators pushed away from zero.

    Args:
        a: Numerator.
        b: Denominator.
        eps: Smallest denominator magnitude allowed.

    Returns:
        ``a / b`` where ``|b| >= eps``. Elsewhere the denominator is replaced
        by ``-eps`` if ``b`` is negative and by ``+eps`` otherwise (zero
        included), so the sign of the quotient is preserved.
    """
    b = np.asarray(b)
    # Keep the denominator's sign when clamping; substituting +eps for a tiny
    # negative value would flip the sign of the result.
    signed_eps = np.where(b < 0, -eps, eps)
    return a / np.where(np.abs(b) < eps, signed_eps, b)

def time_mask(times_s: np.ndarray, t0_ms: float, t1_ms: float) -> np.ndarray:
    """Return a boolean mask selecting the samples inside a time window.

    Args:
        times_s: Sample times in seconds.
        t0_ms: One edge of the window in milliseconds.
        t1_ms: The other edge in milliseconds; the edges may be given in
            either order.

    Returns:
        Mask that is True where the sample time lies within the window,
        both edges included.
    """
    window_start = min(t0_ms, t1_ms)
    window_end = max(t0_ms, t1_ms)
    times_ms = times_s * 1000.0
    at_or_after_start = times_ms >= window_start
    at_or_before_end = times_ms <= window_end
    return at_or_after_start & at_or_before_end

def mean_roi_time(power_ch_time: np.ndarray, times_s: np.ndarray, ch_names: list, roi_chs: list, tmask: np.ndarray) -> float:
    """Average power over the ROI channels and the masked time samples.

    Args:
        power_ch_time: Array indexed as (channel, time).
        times_s: Not used by this function.
        ch_names: Channel names, in the row order of ``power_ch_time``.
        roi_chs: Channel names forming the ROI; names missing from
            ``ch_names`` are skipped.
        tmask: Boolean mask over the time axis.

    Returns:
        NaN-ignoring mean of the selected values, or NaN when none of the
        ROI channels is present.
    """
    rows = [ch_names.index(name) for name in roi_chs if name in ch_names]
    if not rows:
        return np.nan
    roi_power = power_ch_time[rows, :]
    windowed = roi_power[:, tmask]
    return float(np.nanmean(windowed))


# =========================
# FEATURE EXTRACTION
# =========================
def compute_condition_features(epochs: mne.Epochs, cond_label: str) -> dict:
    """Compute trial-averaged band power for one condition.

    Args:
        epochs: Epochs containing the event named ``cond_label``.
        cond_label: Event label selecting the condition's trials.

    Returns:
        Dict with:
            ``"times_s"``: time axis matching the power arrays,
            ``"ch_names"``: channel names of the selected epochs,
            ``"band_mean"``: ``{band: array (n_ch, n_times)}`` for every
            band in ``BANDS``, averaged over epochs.

    Raises:
        ValueError: If the event is missing (from ``select_epochs``) or a
            band's time axis does not match ``"times_s"``.
    """
    ep = select_epochs(epochs, cond_label)

    band_mean = {
        # compute_band_power_morlet gives (n_epochs, n_ch, n_times);
        # averaging over axis 0 leaves (n_ch, n_times).
        band: compute_band_power_morlet(ep, f_lo, f_hi).mean(axis=0)
        for band, (f_lo, f_hi) in BANDS.items()
    }

    # Power is computed with decim=DECIMATE, i.e. every DECIMATE-th sample
    # (or the given slice) is kept, so the time axis must be subsampled the
    # same way. Returning the full ep.times would break masking downstream
    # whenever DECIMATE > 1.
    step = DECIMATE if isinstance(DECIMATE, slice) else slice(None, None, int(DECIMATE))
    times_s = ep.times[step].copy()

    for band, arr in band_mean.items():
        if arr.shape[-1] != len(times_s):
            raise ValueError(
                f"{cond_label}: {band} power has {arr.shape[-1]} time samples "
                f"but the time axis has {len(times_s)}"
            )

    return {
        "times_s": times_s,
        "ch_names": ep.ch_names,
        "band_mean": band_mean,
    }

def average_conditions(a: dict, b: dict) -> dict:
    """Combine two condition feature dicts with equal weight.

    For every band in ``BANDS`` the result holds the element-wise mean of the
    two inputs' ``"band_mean"`` arrays. ``"times_s"`` and ``"ch_names"`` are
    taken from ``a``; ``b`` is not checked against them here.
    """
    merged = {
        band: 0.5 * (a["band_mean"][band] + b["band_mean"][band])
        for band in BANDS
    }
    return {"times_s": a["times_s"], "ch_names": a["ch_names"], "band_mean": merged}


# =========================
# MAIN
# =========================
def main():
    """Run the interactive pipeline from folder selection to CSV export.

    Steps:
        1. Ask for the LTA, HTA and output folders (cancelling any dialog
           ends the run without output).
        2. For each epoch file, compute band power for the four events,
           average TR/TL into "TST" and GR/GL into "GT", and take ROI means
           inside ``TIME_WINDOW_MS``.
        3. Write a long-format CSV (rows with a NaN value dropped) and a
           wide-format CSV with a ``TSTminusGT`` column when both conditions
           are present.

    Raises:
        ValueError: If a file lacks a required event, or the TST and GT
            features disagree on time axis or channel list.
    """
    # Ask for each folder in turn; a cancelled dialog stops before the next.
    chosen = []
    for prompt in (
        "Select LTA folder (epoch FIF files)",
        "Select HTA folder (epoch FIF files)",
        "Select output folder",
    ):
        picked = pick_folder(prompt)
        if not picked:
            return
        chosen.append(picked)
    lta_dir, hta_dir, out_dir = chosen
    ensure_dir(out_dir)

    lta_files = gather_files(lta_dir)
    hta_files = gather_files(hta_dir)
    if not (lta_files and hta_files):
        messagebox.showerror("Error", "No epoch FIF files found in one or both folders.")
        return

    t0_ms, t1_ms = TIME_WINDOW_MS
    required_labels = [TR_LABEL, TL_LABEL, GR_LABEL, GL_LABEL]

    def subject_rows(fp, group_name):
        """Return the long-format rows (one dict each) for a single epoch file."""
        sid = subj_id_from_path(fp)
        print(f"Loading {group_name}: {sid}")

        epochs = mne.read_epochs(fp, preload=True, verbose="ERROR")

        # Check all required events up front, before any power is computed.
        for lab in required_labels:
            if find_event_key(epochs.event_id, lab) is None:
                raise ValueError(f"{sid}: missing event '{lab}'. Available: {list(epochs.event_id.keys())}")

        feat_TR, feat_TL, feat_GR, feat_GL = (
            compute_condition_features(epochs, lab) for lab in required_labels
        )
        feat_TST = average_conditions(feat_TR, feat_TL)
        feat_GT = average_conditions(feat_GR, feat_GL)

        times_s = feat_TST["times_s"]
        ch_names = feat_TST["ch_names"]

        if not np.allclose(times_s, feat_GT["times_s"]):
            raise ValueError(f"{sid}: time axis mismatch between TST and GT")
        if ch_names != feat_GT["ch_names"]:
            raise ValueError(f"{sid}: channel list mismatch between TST and GT")

        tmask = time_mask(times_s, t0_ms, t1_ms)
        TA_binary = int(group_name.upper() == "HTA")

        subject_out = []
        for condition_name, feat in (("TST", feat_TST), ("GT", feat_GT)):
            bands = feat["band_mean"]
            theta = bands["theta"]
            alpha = bands["alpha"]
            beta = bands["beta"]
            tbr = safe_divide(theta, beta)

            # (roi name, measure name, data, ROI channels), in output order.
            for roi_name, measure, data, roi_chs in (
                ("posterior", "alpha", alpha, ROI_POSTERIOR),
                ("frontocentral", "theta", theta, ROI_FRONTOCENTRAL),
                ("parietal", "theta", theta, ROI_PARIETAL),
                ("frontocentral", "TBR", tbr, ROI_FRONTOCENTRAL),
            ):
                subject_out.append({
                    "subj": sid,
                    "group": group_name,
                    "TA_binary": TA_binary,
                    "condition": condition_name,
                    "roi": roi_name,
                    "measure": measure,
                    "value": mean_roi_time(data, times_s, ch_names, roi_chs, tmask),
                    "t_start_ms": t0_ms,
                    "t_end_ms": t1_ms,
                })
        return subject_out

    rows = []
    for group_name, files in (("LTA", lta_files), ("HTA", hta_files)):
        for fp in files:
            rows.extend(subject_rows(fp, group_name))

    df_long = pd.DataFrame(rows)
    df_long = df_long.dropna(subset=["value"])
    df_long = df_long.reset_index(drop=True)

    out_long = os.path.join(out_dir, "balanced_regression_df_long.csv")
    df_long.to_csv(out_long, index=False)

    # One row per subject/ROI/measure, with a column for each condition.
    id_columns = ["subj", "group", "TA_binary", "roi", "measure", "t_start_ms", "t_end_ms"]
    wide = df_long.pivot_table(
        index=id_columns,
        columns="condition",
        values="value",
        aggfunc="mean",
    )
    wide = wide.reset_index()

    if "TST" in wide.columns and "GT" in wide.columns:
        wide["TSTminusGT"] = wide["TST"] - wide["GT"]

    out_wide = os.path.join(out_dir, "balanced_regression_df_wide.csv")
    wide.to_csv(out_wide, index=False)

    messagebox.showinfo("Done", f"Saved:\n{out_long}\n{out_wide}")
    print("\nSaved:")
    print(" ", out_long)
    print(" ", out_wide)

# Run the pipeline only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
